// Copyright (c) Runtime Verification, Inc. All Rights Reserved.
package org.kframework.compile

import java.util
import org.kframework.attributes.Att
import org.kframework.builtin.Sorts
import org.kframework.compile.ConfigurationInfo.Multiplicity
import org.kframework.definition.Module
import org.kframework.definition.NonTerminal
import org.kframework.definition.Production
import org.kframework.definition.Rule
import org.kframework.kore._
import org.kframework.kore.KORE.KApply
import org.kframework.kore.KORE.KLabel
import org.kframework.utils.errorsystem.KEMException
import org.kframework.Collections
import org.kframework.DirectedGraph._
import org.kframework.POSet
import scala.collection.{ IndexedSeq => _, Seq => _, _ }
import scala.jdk.CollectionConverters._

/** Companion object of the `ConfigurationInfoFromModule` class; it declares no members. */
object ConfigurationInfoFromModule

/**
 * A `ConfigurationInfo` computed from the productions, sorts, subsort relation and rules of a
 * single module.
 *
 * Cell sorts are the sorts of productions carrying `Att.CELL`; cell bag sorts are the sorts of
 * productions carrying `Att.CELL_COLLECTION`. Every lookup table below is a plain `val` and is
 * therefore built at construction time, except for the members declared `lazy`.
 *
 * @param m
 *   the module that is inspected
 */
class ConfigurationInfoFromModule(val m: Module) extends ConfigurationInfo {

  /** `(sort, production)` pairs for every production of `m` carrying `Att.CELL`. */
  private val cellProductionsSet: Set[(Sort, Production)] =
    m.productions.filter(_.att.contains(Att.CELL)).map(p => (p.sort, p))

  /** `(sort, production)` pairs for every production of `m` carrying `Att.CELL_COLLECTION`. */
  private val cellBagProductionsSet: Set[(Sort, Production)] =
    m.productions.filter(_.att.contains(Att.CELL_COLLECTION)).map(p => (p.sort, p))

  private val cellSorts: Set[Sort]    = cellProductionsSet.map(_._1)
  private val cellBagSorts: Set[Sort] = cellBagProductionsSet.map(_._1)

  /**
   * Indexes the given productions by sort, ignoring productions that carry `Att.INTERNAL`.
   *
   * @param cells
   *   `(sort, production)` pairs to index
   * @return
   *   a map holding exactly one non-internal production per sort
   * @throws KEMException
   *   (compiler error) if two non-internal productions share the same sort
   */
  private def buildCellProductionMap(cells: Set[(Sort, Production)]): Map[Sort, Production] =
    // A single left fold visits the pairs in the set's iteration order without rebuilding the
    // remaining set at every step.
    cells.foldLeft(Map[Sort, Production]()) { case (cellMap, (s, p)) =>
      if (p.att.contains(Att.INTERNAL))
        cellMap
      else if (cellMap.contains(s))
        throw KEMException.compilerError("Too many productions for cell sort: " + s)
      else
        cellMap.concat(Map(s -> p))
    }

  private val cellProductions: Map[Sort, Production] = buildCellProductionMap(cellProductionsSet)
  private val cellBagProductions: Map[Sort, Production] = buildCellProductionMap(
    cellBagProductionsSet
  )

  /** For every cell bag sort with a production, the sorts from [[getCellSortsOfCellBag]]. */
  private val cellBagSubsorts: Map[Sort, Set[Sort]] =
    cellBagProductions.values.map(p => (p.sort, getCellSortsOfCellBag(p.sort))).toMap

  /** Label of each cell production; `klabel.get` fails if a cell production has no label. */
  private val cellLabels: Map[Sort, KLabel] = cellProductions.view.mapValues(_.klabel.get).toMap

  /** Inverse of [[cellLabels]]. */
  private val cellLabelsToSorts: Map[KLabel, Sort] = cellLabels.map(_.swap)

  /** Maps the sort stored in a production's `Att.CELL_FRAGMENT` attribute to its label. */
  private val cellFragmentLabel: Map[Sort, KLabel] =
    m.productions
      .filter(_.att.contains(Att.CELL_FRAGMENT, classOf[Sort]))
      .map(p => (p.att.get(Att.CELL_FRAGMENT, classOf[Sort]), p.klabel.get))
      .toMap

  /** Maps the sort stored in a production's `Att.CELL_OPT_ABSENT` attribute to its label. */
  private val cellAbsentLabel: Map[Sort, KLabel] =
    m.productions
      .filter(_.att.contains(Att.CELL_OPT_ABSENT, classOf[Sort]))
      .map(p => (p.att.get(Att.CELL_OPT_ABSENT, classOf[Sort]), p.klabel.get))
      .toMap

  /**
   * Initializer application for each cell sort, taken from productions of a cell or cell bag sort
   * that carry `Att.INITIALIZER`. An initializer whose sort is a cell bag sort is registered for
   * every sort returned by [[getCellSortsOfCellBag]] for that bag.
   */
  private val cellInitializer: Map[Sort, KApply] =
    m.productions
      .filter(p => (cellSorts(p.sort) || cellBagSorts(p.sort)) && p.att.contains(Att.INITIALIZER))
      .map(p => (p.sort, KApply(p.klabel.get)))
      .flatMap { case (s, app) =>
        if (cellBagSorts(s)) getCellSortsOfCellBag(s).map((_, app)) else immutable.Seq((s, app))
      }
      .toMap

  /**
   * Parent-to-child pairs between cell sorts. A non-terminal of a cell sort yields one edge; a
   * non-terminal of a cell bag sort yields one edge per sort of [[getCellSortsOfCellBag]].
   */
  private val edges: Set[(Sort, Sort)] = cellProductions.toList.flatMap { case (s, p) =>
    p.items.flatMap {
      case NonTerminal(n, _) if cellSorts.contains(n) => List((s, n))
      case NonTerminal(n, _) if cellBagSorts.contains(n) =>
        getCellSortsOfCellBag(n).map(subsort => (s, subsort))
      case _ => List()
    }
  }.toSet

  /** All sorts `s` of `m` for which `m.subsorts.directlyGreaterThan(n, s)` holds. */
  private def getCellSortsOfCellBag(n: Sort): Set[Sort] =
    m.allSorts.filter(m.subsorts.directlyGreaterThan(n, _))

  /** The cell bag sorts `s` of `m` for which `m.subsorts.directlyLessThan(n, s)` holds. */
  override def getCellBagSortsOfCell(n: Sort): Set[Sort] =
    m.allSorts.filter(m.subsorts.directlyLessThan(n, _)).intersect(cellBagSorts)

  private val edgesPoset: POSet[Sort] = new POSet(edges)

  /** Cell sorts that are not the child of any edge. */
  private lazy val topCells = cellSorts.diff(edges.map(_._2))

  private val sortedSorts: immutable.Seq[Sort] = Collections.immutable(edgesPoset.sortedElements())

  /** The edges ordered by the position of their parent sort within [[sortedSorts]]. */
  private val sortedEdges: immutable.Seq[(Sort, Sort)] = {
    // Look positions up in a map rather than scanning `sortedSorts` for every comparison.
    // Reversing before `toMap` lets the first occurrence of a sort win, and absent sorts map to
    // -1, so the keys are the same as those `indexOf` would produce.
    val position: Map[Sort, Int] = sortedSorts.zipWithIndex.reverse.toMap
    edges.toList.sortBy { case (parent, _) => position.getOrElse(parent, -1) }
  }

  /**
   * Level of each cell sort: top cells are at level 0 and the child of an edge is one level below
   * its parent. Edges are processed in [[sortedEdges]] order; looking up a parent that has no
   * level yet fails with the map's `NoSuchElementException`.
   */
  val levels: Map[Sort, Int] = sortedEdges.foldLeft(topCells.map((_, 0)).toMap) {
    case (acc, (from, to)) => acc + (to -> (acc(from) + 1))
  }

  /**
   * The sort of the unique cell production carrying `Att.MAINCELL`.
   *
   * Evaluating it throws a `KEMException` compiler error when there is no such production or
   * more than one.
   */
  private lazy val mainCell = {
    val mainCells = cellProductions.filter(x => x._2.att.contains(Att.MAINCELL)).map(_._1)
    if (mainCells.size > 1)
      throw KEMException.compilerError("Too many main cells:" + mainCells)
    if (mainCells.isEmpty)
      throw KEMException.compilerError("No main cell found")
    mainCells.head
  }

  /** Level of `k` from [[levels]], or -1 when `k` has no recorded level. */
  override def getLevel(k: Sort): Int = levels.getOrElse(k, -1)

  /** Whether `k` is the parent in at least one edge. */
  override def isParentCell(k: Sort): Boolean = edges.exists(_._1 == k)

  /**
   * `STAR` if `k` belongs to the sorts of some cell bag, `OPTIONAL` if its cell production carries
   * `Att.UNIT`, and `ONE` otherwise. Looking up a `k` without a cell production fails unless `k`
   * is a cell bag member.
   */
  override def getMultiplicity(k: Sort): Multiplicity =
    // `exists` answers the membership question without building a merged set on every call.
    if (cellBagSubsorts.values.exists(_.contains(k)))
      Multiplicity.STAR
    else if (cellProductions(k).att.contains(Att.UNIT))
      Multiplicity.OPTIONAL
    else
      Multiplicity.ONE

  /** The parent of the first edge whose child is `k`; fails on `.get` when there is none. */
  override def getParent(k: Sort): Sort             = edges.collectFirst { case (p, `k`) => p }.get
  override def isCell(k: Sort): Boolean             = cellSorts.contains(k)
  override def isCellLabel(kLabel: KLabel): Boolean = cellLabelsToSorts.contains(kLabel)
  override def isLeafCell(k: Sort): Boolean         = !isParentCell(k)

  /**
   * Sorts of the non-terminals of `k`'s cell production, in item order, with every cell bag sort
   * replaced by the sorts of [[getCellSortsOfCellBag]].
   */
  override def getChildren(k: Sort): util.List[Sort] = cellProductions(k).items
    .collect { case nt: NonTerminal => nt.sort }
    .flatMap { s =>
      if (cellBagSorts(s))
        getCellSortsOfCellBag(s).to(immutable.Seq)
      else
        immutable.Seq(s)
    }
    .asJava

  /** Sort of the first non-terminal of `k`'s cell production; fails on `.get` if none exists. */
  override def leafCellType(k: Sort): Sort = cellProductions(k).items.collectFirst {
    case NonTerminal(n, _) => n
  }.get

  /** The initializer application registered for `k` in [[cellInitializer]]. */
  override def getDefaultCell(k: Sort): KApply = cellInitializer(k)

  /** True when no production for the initializer label of `k` contains a non-terminal. */
  override def isConstantInitializer(k: Sort): Boolean =
    !m.productionsFor(getDefaultCell(k).klabel).exists(_.items.exists(_.isInstanceOf[NonTerminal]))

  override def getCellLabel(k: Sort): KLabel     = cellLabels(k)
  override def getCellSort(kLabel: KLabel): Sort = cellLabelsToSorts(kLabel)

  override def getCellFragmentLabel(k: Sort): KLabel = cellFragmentLabel(k)
  override def getCellAbsentLabel(k: Sort): KLabel   = cellAbsentLabel(k)

  /**
   * The single top cell sort.
   *
   * Throws a `KEMException` compiler error when there is more than one top cell. When there is
   * none, `head` on the empty set fails with `NoSuchElementException`.
   */
  override def getRootCell: Sort = {
    if (topCells.size > 1)
      throw KEMException.compilerError("Too many top cells for module " + m.name + ": " + topCells)
    topCells.head
  }

  /** The main cell sort; see [[mainCell]] for the failure cases. */
  override def getComputationCell: Sort = mainCell

  /**
   * Unit term for `k`. For an `OPTIONAL` cell the label is named by the `Att.UNIT` attribute of
   * the cell production; otherwise it is named by the `Att.UNIT` attribute of the production of
   * the one cell bag sort returned by [[getCellBagSortsOfCell]], which is asserted to be unique.
   */
  override def getUnit(k: Sort): KApply =
    if (getMultiplicity(k) == Multiplicity.OPTIONAL)
      KApply(KLabel(cellProductions(k).att.get(Att.UNIT)))
    else {
      val sorts = getCellBagSortsOfCell(k)
      assert(sorts.size == 1, "Too many cell bags found for cell sort: " + k + ", " + sorts)
      KApply(KLabel(cellBagProductions(sorts.head).att.get(Att.UNIT)))
    }

  /** Label of the production of the one cell bag sort of `k`, which is asserted to be unique. */
  override def getConcat(k: Sort): KLabel = {
    val sorts = getCellBagSortsOfCell(k)
    assert(sorts.size == 1, "Too many cell bags found for cell sort: " + k + ", " + sorts)
    cellBagProductions(sorts.head).klabel.get
  }

  /**
   * A cell sort with exactly one cell bag sort such that `concat` is either the label of that
   * bag's production or the label of the cell's initializer, if any such cell sort exists.
   */
  override def getCellForConcat(concat: KLabel): Option[Sort] = cellSorts
    .map(s => (s, getCellBagSortsOfCell(s)))
    .filter(_._2.size == 1)
    .filter(p =>
      cellBagProductions(p._2.head).klabel.get.equals(concat) || (cellInitializer.contains(
        p._1
      ) && cellInitializer(p._1).klabel == concat)
    )
    .map(_._1)
    .headOption

  /**
   * A cell sort with exactly one cell bag sort whose production's `Att.UNIT` label, applied to no
   * arguments, equals `unitLabel` applied to no arguments, if any such cell sort exists.
   */
  override def getCellForUnit(unitLabel: KLabel): Option[Sort] = {
    val unit = KApply(unitLabel)
    cellSorts
      .map(s => (s, getCellBagSortsOfCell(s)))
      .filter(_._2.size == 1)
      .filter(p => KApply(KLabel(cellBagProductions(p._2.head).att.get(Att.UNIT))).equals(unit))
      .map(_._1)
      .headOption
  }

  /** The rules of `m` that carry `Att.INITIALIZER`. */
  lazy val initRules: Set[Rule] = m.rules.filter(_.att.contains(Att.INITIALIZER))

  /** Every token of sort `Sorts.KConfigVar` found by folding over the bodies of [[initRules]]. */
  lazy val configVars: Set[KToken] = {
    val transformer = new FoldK[Set[KToken]] {
      override def apply(k: KToken): Set[KToken] =
        if (k.sort == Sorts.KConfigVar) Set(k) else unit
      def unit                                        = Set()
      def merge(set1: Set[KToken], set2: Set[KToken]) = set1 | set2
    }
    initRules
      .map(r => transformer.apply(r.body))
      .fold(transformer.unit)(transformer.merge)
  }

  /** All productions of `m` carrying `Att.CELL`, grouped by their sort. */
  lazy val cellProductionsFor: Map[Sort, Set[Production]] =
    m.productions.filter(_.att.contains(Att.CELL)).groupBy(_.sort)

}
